{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5eeje4O8fviH",
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "# Parking with Hindsight Experience Replay\n",
        "\n",
        "##  Warming up\n",
        "We start with a few useful installs and imports:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bzMSuJEOfviP",
        "pycharm": {
          "is_executing": false,
          "name": "#%%\n"
        },
        "cellView": "form"
      },
      "outputs": [],
      "source": [
        "#@title Install environment and agent\n",
        "# Use the %pip magic rather than !pip so that the packages are installed\n",
        "# into the environment of the kernel that runs this notebook.\n",
        "%pip install highway-env\n",
        "# TODO: we use the bleeding edge version because the current stable version does not support the latest gym>=0.21 versions. Revert back to stable at the next SB3 release.\n",
        "%pip install git+https://github.com/DLR-RM/stable-baselines3\n",
        "\n",
        "# Environment\n",
        "import gymnasium as gym\n",
        "# Not referenced by name in this notebook; presumably imported for its side\n",
        "# effect of registering 'parking-v0' with gymnasium (TODO confirm).\n",
        "import highway_env\n",
        "\n",
        "# Agent\n",
        "from stable_baselines3 import HerReplayBuffer, SAC"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "so7yH4ucyB-3"
      },
      "outputs": [],
      "source": [
        "#@title Import helpers for visualization of episodes\n",
        "import sys\n",
        "from tqdm.notebook import trange\n",
        "\n",
        "# Python packages, installed with %pip so they land in the kernel's environment.\n",
        "%pip install tensorboardx gym pyvirtualdisplay\n",
        "# System packages. The kernel runs Python 3, so the Python 3 OpenGL bindings\n",
        "# (python3-opengl) are requested: the legacy python-opengl package targets\n",
        "# Python 2 and is missing from recent Ubuntu releases, which makes apt-get\n",
        "# abort without installing xvfb and ffmpeg either.\n",
        "!apt-get install -y xvfb python3-opengl ffmpeg\n",
        "\n",
        "# Fetch the highway-env sources (stderr is discarded, e.g. when the clone\n",
        "# already exists on a re-run) and make its scripts/ folder importable.\n",
        "!git clone https://github.com/eleurent/highway-env.git 2> /dev/null\n",
        "sys.path.insert(0, '/content/highway-env/scripts/')\n",
        "from utils import record_videos, show_videos"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Tensorboard - click the refresh button once training is running\n",
        "\n",
        "# Load the TensorBoard notebook extension, then open the dashboard on the\n",
        "# logs directory, i.e. the folder the training cell passes as tensorboard_log.\n",
        "%load_ext tensorboard\n",
        "%tensorboard --logdir logs"
      ],
      "metadata": {
        "cellView": "form",
        "id": "frgfxpsP3fFn"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y5TOvonYqP-g",
        "pycharm": {
          "name": "#%% \n"
        },
        "cellView": "form"
      },
      "outputs": [],
      "source": [
        "#@title Training\n",
        "\n",
        "LEARNING_STEPS = 5e4 # @param {type: \"number\"}\n",
        "\n",
        "# Environment the agent is trained on.\n",
        "env = gym.make('parking-v0')\n",
        "\n",
        "# Keyword arguments forwarded to HerReplayBuffer through replay_buffer_kwargs.\n",
        "her_kwargs = {\n",
        "    'n_sampled_goal': 4,\n",
        "    'goal_selection_strategy': 'future',\n",
        "}\n",
        "\n",
        "# SAC agent using the HER replay buffer; training metrics are written to the\n",
        "# logs directory that the TensorBoard cell reads from.\n",
        "model = SAC(\n",
        "    'MultiInputPolicy',\n",
        "    env,\n",
        "    replay_buffer_class=HerReplayBuffer,\n",
        "    replay_buffer_kwargs=her_kwargs,\n",
        "    verbose=1,\n",
        "    tensorboard_log=\"logs\",\n",
        "    buffer_size=int(1e6),\n",
        "    learning_rate=1e-3,\n",
        "    gamma=0.95,\n",
        "    batch_size=1024,\n",
        "    tau=0.05,\n",
        "    policy_kwargs={'net_arch': [512, 512, 512]},\n",
        ")\n",
        "\n",
        "# The form field holds a float (5e4), so cast it to an integer step count.\n",
        "model.learn(total_timesteps=int(LEARNING_STEPS))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xOcOP7Of18T2",
        "pycharm": {
          "name": "#%%\n"
        },
        "cellView": "form"
      },
      "outputs": [],
      "source": [
        "#@title Visualize a few episodes\n",
        "\n",
        "N_EPISODES = 10  # @param {type: \"integer\"}\n",
        "\n",
        "# Separate evaluation environment that renders frames as RGB arrays.\n",
        "env = gym.make('parking-v0', render_mode='rgb_array')\n",
        "# record_videos is imported from the cloned scripts/utils.py; presumably it\n",
        "# wraps the environment with a video recorder (TODO confirm in that file).\n",
        "env = record_videos(env)\n",
        "try:\n",
        "    for episode in trange(N_EPISODES, desc=\"Test episodes\"):\n",
        "        obs, info = env.reset()\n",
        "        done = truncated = False\n",
        "        # Roll out the trained policy with deterministic actions until the\n",
        "        # episode terminates or is truncated.\n",
        "        while not (done or truncated):\n",
        "            action, _ = model.predict(obs, deterministic=True)\n",
        "            obs, reward, done, truncated, info = env.step(action)\n",
        "finally:\n",
        "    # Always close the environment, even if a rollout raises or is interrupted.\n",
        "    env.close()\n",
        "show_videos()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "parking_her.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.1 (tags/v3.10.1:2cd268a, Dec  6 2021, 19:10:37) [MSC v.1929 64 bit (AMD64)]"
    },
    "pycharm": {
      "stem_cell": {
        "cell_type": "raw",
        "metadata": {
          "collapsed": false
        },
        "source": []
      }
    },
    "vscode": {
      "interpreter": {
        "hash": "369f2c481f4da34e4445cda3fffd2e751bd1c4d706f27375911949ba6bb62e1c"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}